{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "import pandas as pd\n",
    "\n",
    "col_names = ['sentiment','id','date','query_string','user','text']\n",
    "data_path = 'training.1600000.processed.noemoticon.csv'\n",
    "\n",
    "tweet_data = pd.read_csv(data_path, header=None, names=col_names, encoding=\"ISO-8859-1\").sample(frac=1) # .sample(frac=1) shuffles the data\n",
    "tweet_data = tweet_data[['sentiment', 'text']] # Disregard other columns\n",
    "print(tweet_data.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tUX3YexFy_kz"
   },
   "outputs": [],
   "source": [
    "# Preprocess function\n",
    "import re\n",
    "allowed_chars = ' AaBbCcDdEeFfGgHhIiJjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz0123456789~`!@#$%^&*()-=_+[]{}|;:\",./<>?'\n",
    "punct = '!?,.@#'\n",
    "maxlen = 280\n",
    "\n",
    "def preprocess(text):\n",
    "    return ''.join([' ' + char + ' ' if char in punct else char for char in [char for char in re.sub(r'http\\S+', 'http', text, flags=re.MULTILINE) if char in allowed_chars]])[:maxlen]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply preprocessing\n",
    "tweet_data['text'] = tweet_data['text'].apply(preprocess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put __label__ in front of each sentiment\n",
    "tweet_data['sentiment'] = '__label__' + tweet_data['sentiment'].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "teQJsIJay_k1",
    "outputId": "5b3c28ae-05b4-4cee-e2f9-df50489c23ca"
   },
   "outputs": [],
   "source": [
    "# Save data\n",
    "import os\n",
    "\n",
    "# Create directory for saving data if it does not already exist\n",
    "data_dir = './processed-data'\n",
    "if not os.path.isdir(data_dir):\n",
    "    os.mkdir(data_dir)\n",
    "\n",
    "# Save a percentage of the data (you could also only load a fraction of the data instead)\n",
    "amount = 0.125\n",
    "\n",
    "tweet_data.iloc[0:int(len(tweet_data)*0.8*amount)].to_csv(data_dir + '/train.csv', sep='\\t', index=False, header=False)\n",
    "tweet_data.iloc[int(len(tweet_data)*0.8*amount):int(len(tweet_data)*0.9*amount)].to_csv(data_dir + '/test.csv', sep='\\t', index=False, header=False)\n",
    "tweet_data.iloc[int(len(tweet_data)*0.9*amount):int(len(tweet_data)*1.0*amount)].to_csv(data_dir + '/dev.csv', sep='\\t', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "CRbruBmilCOt",
    "outputId": "5151d827-f90d-4ce6-f6f8-59b02be7f74f"
   },
   "outputs": [],
   "source": [
    "# Memory management\n",
    "del tweet_data\n",
    "import gc; gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 292
    },
    "colab_type": "code",
    "id": "BR5iIWRgy_k5",
    "outputId": "0e1616ab-b1c5-4007-d11d-5ff522fea812"
   },
   "outputs": [],
   "source": [
    "# Load the data into Corpus format\n",
    "from flair.data_fetcher import NLPTaskDataFetcher\n",
    "from pathlib import Path\n",
    "\n",
    "corpus = NLPTaskDataFetcher.load_classification_corpus(Path(data_dir), test_file='test.csv', dev_file='dev.csv', train_file='train.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 69
    },
    "colab_type": "code",
    "id": "_6XPFKdqy_k7",
    "outputId": "68c8caa5-0b73-4e00-bb1d-53c7d8b575d8"
   },
   "outputs": [],
   "source": [
    "# Make label dictionary\n",
    "label_dict = corpus.make_label_dictionary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 350
    },
    "colab_type": "code",
    "id": "f-nIq7k6y_k9",
    "outputId": "2ea897cc-175b-4830-894c-48227d50f46f"
   },
   "outputs": [],
   "source": [
    "# Load embeddings\n",
    "from flair.embeddings import WordEmbeddings, FlairEmbeddings\n",
    "\n",
    "word_embeddings = [WordEmbeddings('glove'),\n",
    "#                    FlairEmbeddings('news-forward'),\n",
    "#                    FlairEmbeddings('news-backward')\n",
    "                  ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WiVHs1aRy_k_"
   },
   "outputs": [],
   "source": [
    "# Initialize embeddings\n",
    "from flair.embeddings import DocumentRNNEmbeddings\n",
    "\n",
    "document_embeddings = DocumentRNNEmbeddings(word_embeddings, hidden_size=512, reproject_words=True, reproject_words_dimension=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "y7x4mV67BPAh"
   },
   "outputs": [],
   "source": [
    "# Create model\n",
    "from flair.models import TextClassifier\n",
    "\n",
    "classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TRCGx_VdA-5F"
   },
   "outputs": [],
   "source": [
    "# Create model trainer\n",
    "from flair.trainers import ModelTrainer\n",
    "\n",
    "trainer = ModelTrainer(classifier, corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 868
    },
    "colab_type": "code",
    "id": "AHgC91YoA-2J",
    "outputId": "a613d884-a079-480c-97bc-4ddd500bf9b6",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train the model\n",
    "trainer.train('model-saves',\n",
    "              learning_rate=0.1,\n",
    "              mini_batch_size=32,\n",
    "              anneal_factor=0.5,\n",
    "              patience=8,\n",
    "              max_epochs=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ozdqd8KiA-zZ"
   },
   "outputs": [],
   "source": [
    "# Load the model and make predictions\n",
    "from flair.data import Sentence\n",
    "\n",
    "classifier = TextClassifier.load('model-saves/final-model.pt')\n",
    "\n",
    "pos_sentence = Sentence(preprocess('I love Python!'))\n",
    "neg_sentence = Sentence(preprocess('Python is the worst!'))\n",
    "\n",
    "classifier.predict(pos_sentence)\n",
    "classifier.predict(neg_sentence)\n",
    "\n",
    "print(pos_sentence.labels, neg_sentence.labels)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "sentiment-analysis.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
